
% Start from an empty workspace and a clean command window, then move to
% the replication root so that the relative paths below resolve.
clear
clc

cd('E:\data_replication')

% Row of index_price_coeff to process (left unsuppressed, so it is echoed).
i_k = 2000

% Import Data and Estimates
%==========================================================================

% Import Demand Parameter Estimates
%--------------------------------------------------------------------------

load('estimation/5_supply_side/main_data_supply.mat');
% Keep only what the rest of the script reads from this file.
clearvars('-except', 'i_k', 'index_price_coeff', 'X_MPEC_opt_summary', 'Y_summary', 'K', 'K_A');

% k selects the entry of the per-k cell arrays that belongs to row i_k.
k = index_price_coeff(i_k, 1);
X_MPEC_opt = X_MPEC_opt_summary{k};
Y_k = Y_summary{k};

% Unpack individual coefficients from the parameter vector.
alpha_constant   = X_MPEC_opt(1, 1);
alpha_distance   = X_MPEC_opt(2, 1);
alpha_population = X_MPEC_opt(3, 1);
alpha_price      = X_MPEC_opt(4, 1);
alpha_1          = X_MPEC_opt(12, 1);

% The (K - K_A) entries that follow position 3*(K_A+1) of the vector.
alpha_FE = X_MPEC_opt(3*(K_A + 1) + (1:(K - K_A)), 1);


% Import Cost Estimates
%--------------------------------------------------------------------------
% Loads the table data_save for i_k, unpacks its columns into workspace
% vectors and derives, for every market id 1..T, the block of rows it
% occupies (rows are treated as grouped by market in ascending id order).

filename_costs_matlab = sprintf('robustness/firm_heterogeneity/cost_estimates/costs_%i_k.mat', i_k);
load(filename_costs_matlab)

numProdsTotal = size(data_save, 1);

declarant = data_save.Declarant;
partner = data_save.Partner;
data_k = [data_save.Year, data_save.Quarter, declarant, partner];

price_k = data_save.price_k;
share_k = data_save.share_k;
distance_k = data_save.distance_k;
population_log_k = data_save.population_log_k;
Y_k = Y_summary{k};
FE_k = [data_save.FE_k1, data_save.FE_k2, data_save.FE_k3, data_save.FE_k4];

market_id_k = data_save.market_id;
T = market_id_k(end, 1);          % id found in the last row
nn = size(Y_k, 2);                % number of columns of Y_k

q_k = data_save.q_k;
mc_k = data_save.mc_opt;
m1 = data_save.m1;
% Inverts mc = exp(m0 * q^m1) for m0.
m0 = log(mc_k) ./ (q_k.^m1);

% Number of rows carrying each market id.
prodsMarket = arrayfun(@(t) sum(market_id_k(:, 1) == t), (1:T)');

% First and last row of every market: ends are the running totals of the
% counts, starts follow from ends and counts.
marketEnds = cumsum(prodsMarket);
marketStarts = marketEnds - prodsMarket + 1;

% Market id for every row, built from the row ranges.
marketForProducts = zeros(numProdsTotal, 1);
for t = 1:T
    marketForProducts(marketStarts(t):marketEnds(t)) = t;
end

f_k0 = data_save.f_opt;


% Import Fixed Cost
%--------------------------------------------------------------------------

filename_f_opt = sprintf('robustness/firm_heterogeneity/cost_estimates/data_f_opt_%i_k.mat', i_k);
load(filename_f_opt);
% Expand the per-market values in data_f_opt.f_opt to one value per row.
f_k1 = data_f_opt.f_opt(marketForProducts);

% Equal-weight average of the two fixed-cost series.
f_k = 0.5 * f_k0 + 0.5 * f_k1;

% Compute Market Shares
%--------------------------------------------------------------------------
% Simulated shares for all rows at once: for each of the nn columns of Y_k
% the numerator exp(delta) .* exp(mu) is divided by its within-market sum,
% and the result is averaged over the nn columns.

delta_k =  q_k + log(price_k)*alpha_price + alpha_population*population_log_k + alpha_distance*distance_k;

% Preallocate instead of growing expmu one market at a time.
expmu = zeros(numProdsTotal, nn);
for t = 1:T
    Y_market = Y_k(t,:);
    rows_t = marketStarts(t):marketEnds(t);
    expmu(rows_t,:) = exp(log(price_k(rows_t,:))*alpha_1*Y_market);
end
expmeanval = exp(delta_k);
oo = ones(1,nn);

% T x numProdsTotal indicator (row t marks the rows of market t), used to
% sum the numerators within each market. Built directly in sparse form:
% the previous sparse(zeros(T,numProdsTotal)) allocated the full dense
% matrix first and then inserted entries into the sparse matrix in a loop.
sharesum = sparse(marketForProducts, (1:numProdsTotal)', 1, T, numProdsTotal);

numer = (expmeanval*oo ).*expmu;
sum1 = sharesum*numer;
%sum11 = 1./(1+sum1);                                                      % Will result in share_k * 0.9 = EstShare_true
sum11 = 1./(0+sum1);                                                       % Will not result in share_k = EstShare_true
denom1 = sum11(marketForProducts,:);
simShare_true = numer.*denom1;
EstShare_true = mean(simShare_true,2);



% Estimation
%==========================================================================
% Loads the remaining inputs, switches to the output directory and creates
% the empty accumulators that the market loop appends to.

load('estimation/5_supply_side/companies_k.mat');
companies_k = company_summary{k};

load('data/firm_size_distribution/prod_summary.mat');
% NOTE(review): indexed by i_k, whereas the other per-k cells use k - confirm.
prod_pct = prod_pct_summary{i_k};

cd('E:\data_replication\robustness\firm_heterogeneity')

% One accumulator per output column group; every one starts empty.
[data_k_summary, price_k_summary, mc_k_summary, share_k_summary, ...
    m1_k_summary, q_k_summary, distance_k_summary, ...
    population_log_k_summary, FE_k_summary, market_id_summary, ...
    p_flag_summary] = deal([]);

% Market loop. For every market id ttt:
%   1. restrict all variables to the rows of that market and flag the
%      domestic rows (Declarant == Partner);
%   2. if there is at least one domestic row, add one firm at a time and
%      re-solve the system solve_mc_e with fsolve until some firm's profit
%      (from get_share_e) is below the market's fixed cost, the solver
%      does not return exit flag 1, no valid start point is found, or more
%      than 99 firms have been added;
%   3. otherwise write a single placeholder row filled with 9999;
%   4. append the market's rows to the *_summary accumulators.
% solve_mc_e and get_share_e are defined outside this file.
for ttt = 1:T

    %ttt = 1
    % Left unsuppressed on purpose: echoes the current market id.
    ttt

    % Select the rows of market ttt and flag domestic firms
    index_market = market_id_k == ttt;
    declarant_market = declarant(index_market,:);
    partner_market = partner(index_market,:);
    pos_dom = find(declarant_market == partner_market);
    index_d = declarant_market == partner_market;
    numProdsTotal_market = sum(index_market, 1);
    prodsMarket_market = numProdsTotal_market;

    marketStarts_market = 1;
    marketEnds_market = prodsMarket_market;
    marketForProducts_market = ones(prodsMarket_market, 1);

    % Restrict every series to the rows of this market (all firms, not
    % only domestic ones; the domestic subset is taken further below).
    share_k_market = share_k(index_market,:);
    price_k_market = price_k(index_market,:);
    q_k_market = q_k(index_market,:);
    data_k_market = data_k(index_market,:);
    FE_k_market = FE_k(index_market,:);
    population_log_k_market = population_log_k(index_market,:);
    distance_k_market = distance_k(index_market,:);
    mc_k_market = mc_k(index_market,:);
    m1_k_market = m1(index_market,:);
    m0_k_market = m0(index_market,:);
    N_k = size(price_k_market,1);
    N_k_0 = sum(index_d);                 % number of domestic firms observed
    f_market = f_k(index_market,:);
    f_market = f_market(1,1);             % fixed cost taken from the market's first row

    % Simulated shares of all firms in this market
    delta_market = q_k_market + log(price_k_market)*alpha_price + alpha_population*population_log_k_market + alpha_distance*distance_k_market;
    Y_market = Y_k(ttt,:);
    expmu_market = exp(log(price_k_market)*alpha_1*Y_market);
    expmeanval_market = exp(delta_market);
    oo = ones(1,nn);
    sharesum_market = sparse(zeros(1,size(price_k_market,1)));
    sharesum_market(1,:) = 1;
    numer = (expmeanval_market*oo ).*expmu_market;
    sum1 = sharesum_market*numer;
    sum11 = 1./(0+sum1);
    denom1 = sum11(marketForProducts_market,:);
    simShare_true = numer.*denom1;
    EstShare_true = mean(simShare_true,2);

    % NOTE(review): the five quantities below (and simShare_true /
    % EstShare_true above) are not read again inside this loop.
    sum_s = simShare_true;
    sum_s2 = simShare_true.^2;
    alpha_i = alpha_price + alpha_1*Y_market;
    sum_alpha_s = alpha_i.*simShare_true;
    sum_alpha_s2 = alpha_i.*(simShare_true.^2);

    % Initialize market data prior entry (observed domestic rows only)
    N_d = sum(index_d);
    price_k_ttt = price_k_market(index_d,:);
    mc_k_ttt = mc_k_market(index_d,:);
    share_k_ttt = share_k_market(index_d,:);
    m1_k_ttt = m1_k_market(index_d,:);
    q_k_ttt = q_k_market(index_d,:);
    data_k_ttt = data_k_market(index_d,:);
    distance_k_ttt = distance_k_market(index_d,:);
    population_log_ttt = population_log_k_market(index_d,:);
    FE_k_ttt = FE_k_market(index_d,:);
    market_id_ttt = ttt * ones(size(price_k_ttt,1), 1);
    p_flag_ttt = 1 * ones(size(price_k_ttt,1), 1);

    if sum(index_d) > 0                                                        % If there is a dom firm

        index_entry = 0;

        % Each pass adds one more firm; index_entry becomes nonzero as soon
        % as one of the stopping conditions below is met.
        while index_entry == 0

            N_d = N_d + 1;

            % Market-level covariates are copied from the market's last row.
            population_log_d = population_log_k_market(end,1) * ones(N_d,1);
            distance_d = distance_k_market(end,1) * ones(N_d,1);
            FE_d = ones(N_d,1) * FE_k_market(end,:);
            Y_d = Y_market;

            % Pick Start Values: observed domestic rows first, added firms
            % start from the last row's price/q and a cost of 0.9 * price.

            price_d = [price_k_market(index_d,:); ones(N_d - N_k_0, 1) * price_k_market(end,:)];
            q_d = [q_k_market(index_d,:); ones(N_d - N_k_0, 1) * q_k_market(end,:)];
            m0_d = m0_k_market(index_d,:);
            m1_d_1 = m1_k_market(index_d,:);
            m1_d_1 = m1_d_1(1,1);
            mc_e = 0.9 * price_d((N_k_0 + 1):end, 1);
            prod_pct_d = prod_pct(1:N_d,1);


            x0 = [price_d; q_d; mc_e];
            x = x0;

            options_fsolve = optimoptions('fsolve', 'Algorithm', 'levenberg-marquardt', 'MaxFunEvals', 50000, 'MaxIter', 5000, 'FunctionTolerance', 1e-8);
            y_init = solve_mc_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);

            % If the system evaluates to NaN or -Inf at the start point,
            % redraw random prices/costs, at most 30 times (jj counts draws).
            jj = 0;
            if (any(isnan(y_init)) || any(y_init == -Inf)) == 1
                while (any(isnan(y_init)) || any(y_init == -Inf)) == 1 && jj < 30
                    jj = jj + 1;
                    price_d = rand(1) * ones(N_d, 1) * price_k_market(end,1);
                    mc_e = rand(1) * price_d((N_k_0 + 1):end, 1);
                    x0 = [price_d; q_d; mc_e];
                    x = x0;
                    y_init = solve_mc_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);
                end
            end

            if jj ~= 30
                solve_mc_e_handle = @(x)solve_mc_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);
                [x_opt, f_opt, p_flag] = fsolve(solve_mc_e_handle, x0, options_fsolve);


                % Exit flags 0, -2, 3, 4 or a complex residual are treated
                % as failures: retry from random start points, up to 30 times.
                if (p_flag == 0 || p_flag == -2 || p_flag == 3 || p_flag == 4 || isreal(f_opt) == 0)
                    i = 0;
                    p_flag_new = -2;
                    f_opt_new = f_opt;
                    while ((p_flag_new == 0 || p_flag_new == -2 || p_flag_new == 3 || p_flag_new == 4 || isreal(f_opt_new) == 0) && i < 30)
                        % Left unsuppressed: echoes the restart counter.
                        i = i + 1
                        price_d = rand(1) * ones(N_d, 1) * price_k_market(end,1);
                        mc_e = rand(1) * price_d((N_k_0 + 1):end, 1);
                        x0 = [price_d; q_d; mc_e];
                        x = x0;
                        y_init = solve_mc_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);
                        if (any(isnan(y_init)) || any(y_init == -Inf)) == 1
                            % NOTE(review): unlike the first redraw loop, this
                            % one has no iteration cap and keeps drawing until
                            % a valid start point is found.
                            while (any(isnan(y_init)) || any(y_init == -Inf)) == 1
                                price_d = rand(1) * ones(N_d, 1) * price_k_market(end,1);
                                mc_e = rand(1) * price_d((N_k_0 + 1):end, 1);
                                x0 = [price_d; q_d; mc_e];
                                x = x0;
                                y_init = solve_mc_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);
                            end
                        end
                        solve_mc_e_handle = @(x)solve_mc_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);
                        [x_opt_new, f_opt_new, p_flag_new] = fsolve(solve_mc_e_handle, x0, options_fsolve);
                    end

                    % Adopt the restart only if it ended with exit flag 1;
                    % otherwise the first (failed) solution is kept.
                    if p_flag_new == 1
                        x_opt = x_opt_new;
                        f_opt = f_opt_new;
                        p_flag = p_flag_new;
                    end
                end

                x = x_opt;

                % Unpack the solution vector [price; q; mc of added firms].
                price_d = x_opt(1:N_d, 1);
                q_d = x_opt((N_d + 1):(2*N_d), 1);
                mc_e = x_opt((2*N_d + 1):end, 1);

                % Costs of the observed domestic firms follow from m0, m1
                % and the solved q; costs of added firms are solved directly.
                m1_d = ones(N_d, 1) * m1_d_1;
                mc_d = exp(m0_d .* (q_d(1:N_k_0, 1) .^ m1_d(1:N_k_0, 1)));
                mc_d = [mc_d; mc_e];

                [share_d, profit_d] = get_share_e(x, N_d, m1_d_1, m0_d, N_k_0, alpha_price, alpha_population, population_log_d, alpha_distance, distance_d, delta_market, index_d, alpha_1, Y_d, expmu_market, nn, Y_market, prod_pct_d);
                share_all = share_d;

                % Stop adding firms once any profit is below the fixed cost.
                index_entry = any(profit_d < f_market);

                % Still profitable and solved cleanly: this configuration
                % becomes the market's current result.
                if index_entry == 0 && p_flag == 1

                    price_k_ttt = price_d;
                    mc_k_ttt = mc_d;
                    share_k_ttt = share_all;
                    m1_k_ttt = m1_d;
                    q_k_ttt = q_d;

                    % Added firms reuse the identifiers of the last domestic row.
                    data_k_ttt = data_k_market(index_d, :);
                    data_d = repmat(data_k_ttt(end,:), N_d - N_k_0, 1);
                    data_k_ttt = [data_k_ttt; data_d];

                    distance_k_ttt = distance_d;
                    population_log_ttt = population_log_d;
                    FE_k_ttt = FE_d;

                    market_id_ttt = ttt * ones(size(price_k_ttt,1), 1);

                    p_flag_ttt = p_flag * ones(size(price_k_ttt,1), 1);

                end

                % Solver did not return exit flag 1: stop and stamp the
                % last stored rows with the failing flag.
                if p_flag ~= 1
                    index_entry = 1;
                    p_flag_ttt = p_flag * ones(size(price_k_ttt,1), 1);
                end

                % Hard cap: stop once more than 99 firms have been added.
                if (N_d - N_k_0) > 99
                    index_entry = 1;
                end

            end

            % No valid start point after 30 draws: stop and flag rows with -2.
            if jj == 30
                index_entry = 1;
                p_flag_ttt = (-2) * ones(size(price_k_ttt,1), 1);
            end

        end

    else

        % No domestic firm in this market: store one placeholder row in
        % which every value except year/quarter/declarant and the market id
        % is 9999.
        N_d = 1;
        price_k_ttt = 9999;
        mc_k_ttt = 9999;
        share_k_ttt = 9999;
        m1_k_ttt = 9999;
        q_k_ttt = 9999;
        data_k_ttt = data_k(index_market,1:3);
        data_k_ttt = [data_k_ttt(1,:), 9999];
        distance_k_ttt = 9999;
        population_log_ttt = 9999;
        FE_k_ttt = [9999, 9999, 9999, 9999];
        market_id_ttt = market_id_k(index_market,1);
        market_id_ttt = market_id_ttt(1,1);
        p_flag_ttt = 9999;

    end

    % Firms added beyond the observed domestic ones; the last N_d tried is
    % the one that stopped the loop, hence the extra -1.
    N_new_summary(ttt, 1) = N_d - N_k_0 - 1;
    price_k_summary = [price_k_summary; price_k_ttt];
    mc_k_summary = [mc_k_summary; mc_k_ttt];
    share_k_summary = [share_k_summary; share_k_ttt];
    m1_k_summary = [m1_k_summary; m1_k_ttt];
    q_k_summary = [q_k_summary; q_k_ttt];
    data_k_summary = [data_k_summary; data_k_ttt];
    distance_k_summary = [distance_k_summary; distance_k_ttt];
    population_log_k_summary = [population_log_k_summary; population_log_ttt];
    FE_k_summary = [FE_k_summary; FE_k_ttt];
    market_id_summary = [market_id_summary; market_id_ttt];
    p_flag_summary = [p_flag_summary; p_flag_ttt];

    % Drop all per-market temporaries. N_new_summary must be on the keep
    % list: it was missing before, so the entry counts written above were
    % erased at the end of every iteration.
    clearvars -except i_k ttt price_k_summary mc_k_summary share_k_summary m1_k_summary q_k_summary distance_k_summary population_log_k_summary FE_k_summary market_id_summary p_flag_summary market_id_k declarant partner share_k price_k mc_k q_k data_k FE_k population_log_k distance_k m1 m0 f_k alpha_price alpha_population alpha_distance Y_k alpha_1 nn data_k_summary data_save prod_pct N_new_summary

end

% Keep the loaded table under a new name, then overwrite data_save with the
% stacked per-market results and write it to disk.
data_save_old = data_save;

filename_costs_matlab = sprintf('cost_estimates/costs_entry_%i_k.mat', i_k);

% Column order of the matrix must match the order of the names below.
data_save = array2table( ...
    [market_id_summary, data_k_summary, price_k_summary, mc_k_summary, ...
     share_k_summary, m1_k_summary, q_k_summary, distance_k_summary, ...
     population_log_k_summary, p_flag_summary, FE_k_summary], ...
    'VariableNames', {'market_id', 'Year', 'Quarter', 'Declarant', 'Partner', ...
                      'price_k', 'mc_opt', 'share_k', 'm1', 'q_k', 'distance_k', ...
                      'population_log_k', 'p_flag', 'FE_k1', 'FE_k2', 'FE_k3', 'FE_k4'});

save(filename_costs_matlab, 'data_save');

